
Add AMD ROCm support and documentation. #272

Open
keybangz wants to merge 14 commits into nomadkaraoke:main from keybangz:amd-rocm

Conversation

@keybangz keybangz commented Mar 17, 2026

This pull request introduces comprehensive support for AMD GPUs using ROCm, updates documentation for multi-platform GPU acceleration, and refactors GPU device handling in separator architectures for improved clarity and compatibility. The changes include a new ROCm-specific Dockerfile, detailed ROCm installation and troubleshooting instructions in the README.md, and code updates to ensure correct device assignment and logging for both Demucs and MDXC separators.

Multi-platform GPU acceleration and ROCm support:

  • Added a new Dockerfile.rocm for building and running the application with ROCm and AMD GPU support, including necessary package installations and environment variable setup.
  • Expanded README.md with a new section for AMD GPU/ROCm installation, troubleshooting, and performance optimization, including Docker usage instructions and environment variable details. [1] [2]
  • Updated documentation to highlight multi-platform GPU acceleration, including ROCm, and clarified GPU-specific installation steps for both CUDA/NVIDIA and ROCm/AMD. [1] [2]

Device handling and refactoring in separator architectures:

  • Refactored demucs_separator.py to use an inference_device variable for model loading and inference, ensuring compatibility with ROCm and improving logging/debugging output. [1] [2] [3]
  • Refactored mdxc_separator.py to consistently use an inference_device variable for model loading, improved debug logging, and clarified device assignment for both Roformer and TFC_TDF_net models.

Code formatting and logging improvements:

  • Reformatted code and expanded debug logging in both demucs_separator.py and mdxc_separator.py for better readability and traceability of parameters and device usage. [1] [2] [3] [4] [5] [6]

Documentation structure enhancements:

  • Updated README.md navigation and section headers to include ROCm-specific topics and troubleshooting, improving clarity for users with AMD GPUs.

CUDA/ROCm troubleshooting guidance:

  • Added detailed troubleshooting steps for CUDA and ROCm, including common errors, library version requirements, and recommended commands for resolving GPU provider issues. [1] [2]

This branch will require extensive testing and a decent polish before merging.

Summary by CodeRabbit

  • New Features

    • Added comprehensive AMD ROCm support (GPU orchestration, ROCm-aware models, and ROCm-optimized separation model)
    • New ROCm-enabled Docker image and packaging extras for optional ROCm runtime
  • Documentation

    • Added ROCm installation, Docker usage, troubleshooting, and GPU-acceleration guidance (parity with CUDA)
  • Tools

    • Added ROCm diagnostic and minimal debug scripts
  • Chores

    • Updated ignore list to exclude debug output file

- Add ROCm execution provider detection in separator.py
- Add configure_rocm method for AMD GPU configuration
- Update log_onnxruntime_packages to include ROCm package
- Add rocm extra to pyproject.toml dependencies
- Create Dockerfile.rocm for ROCm-based containers
- Update README with ROCm installation instructions and troubleshooting
- ROCm detection prioritizes ROCMExecutionProvider when available
- Maintains backward compatibility with existing CUDA/NVIDIA support
- Add detailed ROCm installation instructions with exact PyTorch commands
- Add PyTorch version detection in device setup
- Add warning when ROCm ExecutionProvider detected but PyTorch has CUDA support
- Improve documentation for proper ROCm setup process
- Maintain backward compatibility while providing better user guidance

This fixes the issue where 'audio-separator[rocm]' installs CPU/default PyTorch
instead of ROCm-compatible PyTorch, and helps users properly configure ROCm.
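The provider-priority logic described in these commits could be sketched as a pure function. This is a hypothetical simplification, not the actual `configure_rocm` implementation, which also handles session options, logging, and PyTorch build checks:

```python
def pick_providers(available):
    """Order ONNX Runtime execution providers, preferring ROCm over CUDA.

    A simplified sketch of the detection priority described above; `available`
    would come from onnxruntime.get_available_providers() in practice.
    """
    preferred = ["ROCMExecutionProvider", "CUDAExecutionProvider"]
    chosen = [p for p in preferred if p in available]
    # CPU is always kept as the final fallback.
    chosen.append("CPUExecutionProvider")
    return chosen
```

Called as `pick_providers(onnxruntime.get_available_providers())`, this yields ROCm first when it is present, preserving the stated backward compatibility with CUDA/NVIDIA setups.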
coderabbitai bot commented Mar 17, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.


Walkthrough

Adds comprehensive AMD ROCm support: ROCm Docker, dependency extras, ROCm-aware device orchestration, ROCm-optimized Demucs model, STFT/tensor device handling adjustments, README ROCm docs, and diagnostic scripts for ROCm validation.

Changes

| Cohort / File(s) | Summary |
| --- | --- |
| **ROCm Docker & Packaging**<br>`Dockerfile.rocm`, `pyproject.toml`, `.gitignore` | Adds a ROCm base Dockerfile, an optional `onnxruntime-rocm` dependency with `rocm` extras, and ignores `debug_results.txt`. |
| **Top-level docs & diagnostics**<br>`README.md`, `debug_rocm_issues.py`, `debug_minimal.py` | Adds ROCm installation, troubleshooting, and Docker guidance in README; introduces comprehensive and minimal ROCm diagnostic scripts that collect system, PyTorch, ONNX Runtime, STFT, memory, and model load checks. |
| **Separator core & orchestration**<br>`audio_separator/separator/separator.py`, `audio_separator/separator/common_separator.py` | Introduces `is_rocm` state, a `configure_rocm(ort_providers)` method, and ROCm-aware device selection and ONNX Runtime provider orchestration; plumbs ROCm flags into inference paths and logging. |
| **Device/util helpers**<br>`audio_separator/separator/uvr_lib_v5/utils.py` | Adds an `is_rocm()` utility to detect ROCm in the torch build/version. |
| **Architectures: Demucs / MDX / VR**<br>`audio_separator/separator/architectures/demucs_separator.py`, `.../mdxc_separator.py`, `.../vr_separator.py` | Refactors device handling to use an explicit `inference_device` (method signature change in DemucsSeparator), reshapes source handling, and expands logging/formatting; mostly device plumbing and a minor API signature update. |
| **New ROCm-optimized Demucs**<br>`audio_separator/separator/uvr_lib_v5/demucs/hdemucs_rocm.py` | Adds HDemucsROCm with ScaledEmbedding, HEncLayer, MultiWrap, HDecLayer and hybrid time–frequency processing, per-band wrapping, and Wiener masking for ROCm-friendly inference. |
| **Roformer and Roformer loader**<br>`audio_separator/separator/roformer/roformer_loader.py`, `audio_separator/separator/uvr_lib_v5/roformer/bs_roformer.py`, `.../mel_band_roformer.py` | Normalizes string quoting, stabilizes state-dict loading, updates DEFAULT_FREQS_PER_BANDS, adds device-awareness scaffolding (`x_is_rocm`), and refactors STFT/reconstruction handling for ROCm compatibility. |
| **STFT / TFC-TDF internals**<br>`audio_separator/separator/uvr_lib_v5/stft.py`, `audio_separator/separator/uvr_lib_v5/tfc_tdf_v3.py` | Adds a `run_on_cpu` fallback for non-standard/ROCm devices, forces FP32 for stability on those paths, adds try/fallback for complex STFT returns, and ensures correct device transfers post-computation. |
| **Common separator constants & audio I/O**<br>`audio_separator/separator/common_separator.py` | Expands stem mappings, adds an `is_rocm` flag to CommonSeparator init plumbing, improves bit-depth detection and audio write paths with consistent subtype handling and Roformer loader hooks. |

Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant User
    participant CLI
    participant Separator
    participant CommonSeparator
    participant RoformerLoader
    participant ONNXRuntime
    participant PyTorch
    participant HDemucsROCm

    User->>CLI: request separate(audio, use_rocm=True)
    CLI->>Separator: initialize(is_rocm=True)
    Separator->>CommonSeparator: init(is_rocm=True)
    CommonSeparator->>RoformerLoader: _initialize_roformer_loader()
    RoformerLoader-->>CommonSeparator: loader or None

    Separator->>Separator: setup_torch_device()
    Separator->>PyTorch: set device -> cuda (ROCm)
    Separator->>ONNXRuntime: configure_rocm(ort_providers)
    ONNXRuntime-->>Separator: ROCmExecutionProvider configured

    Separator->>CommonSeparator: demix_demucs(mix, inference_device)
    CommonSeparator->>HDemucsROCm: load model on inference_device
    HDemucsROCm->>PyTorch: forward (hybrid TF)
    HDemucsROCm-->>CommonSeparator: separated sources
    CommonSeparator-->>Separator: return stems
    Separator-->>User: output stems
```

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Poem

🐰 I hopped in with ROCm bright,

GPUs humming through the night,
Hybrid nets and STFT tune,
Debug scripts sing by the moon,
Stem-splitting carrots, code takes flight.

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 49.23%, which is below the required threshold of 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (2 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Description Check | ✅ Passed | Check skipped - CodeRabbit's high-level summary is enabled. |
| Title check | ✅ Passed | The title 'Add AMD ROCm support and documentation' clearly and concisely summarizes the main objective of the pull request, which is to add AMD GPU/ROCm support and expand documentation. It is specific, relevant to the primary changes across the codebase, and reflects the core intent accurately. |

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.


Comment @coderabbitai help to get the list of available commands and usage tips.

coderabbitai bot left a comment

Actionable comments posted: 11

🧹 Nitpick comments (17)
audio_separator/separator/uvr_lib_v5/demucs/hdemucs_rocm.py (4)

197-197: Unused loop variable.

The loop variable k is not used within the loop body. Rename to _ to indicate it's intentionally unused.

Proposed fix
-        for k in range(len(split_ratios) + 1):
+        for _ in range(len(split_ratios) + 1):
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@audio_separator/separator/uvr_lib_v5/demucs/hdemucs_rocm.py` at line 197, The
for-loop in hdemucs_rocm.py uses an unused loop variable `k` ("for k in
range(len(split_ratios) + 1):"); rename `k` to `_` to indicate the variable is
intentionally unused (i.e., change to "for _ in range(len(split_ratios) + 1):")
so linters won't flag an unused variable while keeping the loop behavior the
same; locate the loop referencing `split_ratios` in the file and update the
iterator variable accordingly.

634-634: Redundant int() call.

In Python 3, math.ceil() returns an int, so the outer int() call is unnecessary.

Proposed fix
-            le = int(math.ceil(x.shape[-1] / hl))
+            le = math.ceil(x.shape[-1] / hl)

The same applies to lines 654 and 656.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@audio_separator/separator/uvr_lib_v5/demucs/hdemucs_rocm.py` at line 634,
Remove the redundant outer int() wrappers around math.ceil(...) where lengths
are computed (e.g., the assignment to le using math.ceil(x.shape[-1] / hl) and
the similar calculations at the other two occurrences); math.ceil already
returns an int in Python 3, so replace expressions like le = int(math.ceil(...))
with le = math.ceil(...), and do the same for the other two variables (the other
math.ceil uses in the same scope).

51-51: Typo in comment.

"nornalize" should be "normalize".

Proposed fix
-            # when summing gaussian, overscale raises as sqrt(n), so we nornalize by that.
+            # when summing gaussian, overscale raises as sqrt(n), so we normalize by that.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@audio_separator/separator/uvr_lib_v5/demucs/hdemucs_rocm.py` at line 51, Fix
the typo in the inline comment that currently reads "when summing gaussian,
overscale raises as sqrt(n), so we nornalize by that." — change "nornalize" to
"normalize" in the comment within hdemucs_rocm.py (the comment near the Gaussian
summing/overscale logic) so it reads "... so we normalize by that.".

214-214: Consider adding strict=True to zip() for safety.

In Python 3.10+, zip(..., strict=True) raises ValueError if iterables have different lengths. Since ratios and self.layers should always match in length, adding strict=True would catch configuration errors early.

Proposed fix
-        for ratio, layer in zip(ratios, self.layers):
+        for ratio, layer in zip(ratios, self.layers, strict=True):
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@audio_separator/separator/uvr_lib_v5/demucs/hdemucs_rocm.py` at line 214, The
loop using zip(ratios, self.layers) can silently miss length mismatches; change
it to zip(ratios, self.layers, strict=True) so a ValueError is raised if their
lengths differ, updating the for-loop in the same location (the line with for
ratio, layer in zip(ratios, self.layers)) to include strict=True to catch
configuration errors early.
audio_separator/separator/uvr_lib_v5/tfc_tdf_v3.py (2)

6-8: Duplicate is_rocm() function - consider centralizing.

This is_rocm() function is identical to the one in stft.py (lines 4-6). Having duplicate implementations creates risk of divergence if the detection logic needs updating. Consider moving to a shared utility module.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@audio_separator/separator/uvr_lib_v5/tfc_tdf_v3.py` around lines 6 - 8, The
is_rocm() function in tfc_tdf_v3.py duplicates the implementation found in
stft.py; refactor by moving the detection logic into a single shared utility
(e.g., a new or existing utils module) and replace both local definitions with
imports of that shared is_rocm function; update references in tfc_tdf_v3.py and
stft.py to import the centralized is_rocm and remove the duplicate function to
avoid future divergence.

51-61: Avoid catching bare Exception.

The except Exception is too broad and may hide unexpected errors. Consider catching specific exceptions or at minimum logging the error.

Proposed improvement
         try:
             x = torch.stft(
                 x,
                 n_fft=self.n_fft,
                 hop_length=self.hop_length,
                 window=window,
                 center=True,
                 return_complex=False,
             )
-        except Exception as e:
+        except RuntimeError:
             # Fallback: try with return_complex=True
             x_complex = torch.stft(
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@audio_separator/separator/uvr_lib_v5/tfc_tdf_v3.py` around lines 51 - 61,
Replace the broad "except Exception as e:" around the torch.stft fallback with a
narrow, specific exception handler (e.g., "except RuntimeError as e:") and
ensure the caught error is logged before proceeding to the fallback; reference
the torch.stft call and the x_complex -> x stacking logic and either log the
exception via the module logger (including the error message/traceback) or
re-raise unexpected exception types instead of silencing them so failures in
torch.stft(return_complex=...) are visible.
Dockerfile.rocm (2)

8-12: Add --no-install-recommends to reduce image size.

As flagged by static analysis, adding --no-install-recommends prevents unnecessary packages from being installed.

Proposed fix
 # Install necessary packages
-RUN apt-get update && apt-get install -y \
+RUN apt-get update && apt-get install -y --no-install-recommends \
     ffmpeg \
     python3 \
     python3-pip \
     && rm -rf /var/lib/apt/lists/*
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@Dockerfile.rocm` around lines 8 - 12, The apt install line in the Dockerfile
RUN step installs recommended packages increasing image size; update the RUN
command that includes apt-get install -y ffmpeg python3 python3-pip (the install
invocation) to add the --no-install-recommends flag to apt-get install and keep
the existing cleanup (rm -rf /var/lib/apt/lists/*) so only essential packages
are installed and image size is reduced.

24-26: Hardcoded GPU architecture limits portability.

The environment variables HSA_OVERRIDE_GFX_VERSION=10.3.2 and PYTORCH_ROCM_ARCH=gfx1030 are specific to RX 6600 series GPUs. Users with different AMD GPUs (e.g., RX 7000 series, MI-series) will need to override these.

Consider documenting this in comments or using build arguments:

Proposed improvement
-# Set environment variables for AMD RX 6600 series
-ENV HSA_OVERRIDE_GFX_VERSION=10.3.2
-ENV PYTORCH_ROCM_ARCH=gfx1030
+# Default environment variables for AMD RX 6600 series (gfx1032)
+# Override these for other GPUs:
+#   RX 7900: HSA_OVERRIDE_GFX_VERSION=11.0.0, PYTORCH_ROCM_ARCH=gfx1100
+#   MI250X:  HSA_OVERRIDE_GFX_VERSION=9.4.2, PYTORCH_ROCM_ARCH=gfx90a
+ARG HSA_OVERRIDE_GFX_VERSION=10.3.2
+ARG PYTORCH_ROCM_ARCH=gfx1030
+ENV HSA_OVERRIDE_GFX_VERSION=${HSA_OVERRIDE_GFX_VERSION}
+ENV PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH}
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@Dockerfile.rocm` around lines 24 - 26, The Dockerfile hardcodes AMD GPU
targets via ENV HSA_OVERRIDE_GFX_VERSION=10.3.2 and ENV
PYTORCH_ROCM_ARCH=gfx1030 which reduces portability; change these to
configurable build arguments (e.g., ARG HSA_OVERRIDE_GFX_VERSION and ARG
PYTORCH_ROCM_ARCH) and set the ENV values from those args, and add a comment
above explaining they should be overridden for different GPUs (RX7000, MI
series, etc.) so users can pass --build-arg to customize the GPU target at build
time.
audio_separator/separator/uvr_lib_v5/stft.py (2)

64-74: Avoid catching bare Exception.

Similar to tfc_tdf_v3.py, catching bare Exception is too broad. RuntimeError is more appropriate for PyTorch operation failures.

Proposed fix
-        except Exception as e:
+        except RuntimeError:
             # Fallback: try with return_complex=True
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@audio_separator/separator/uvr_lib_v5/stft.py` around lines 64 - 74, The
except block currently catches a bare Exception around the torch.stft call;
change it to catch RuntimeError instead to avoid swallowing unrelated errors.
Specifically, in the code that computes stft_complex from reshaped_tensor (using
torch.stft with self.n_fft, self.hop_length and stft_window) replace "except
Exception as e" with "except RuntimeError as e" and keep the fallback logic that
constructs stft_output from stft_complex.real and .imag so only PyTorch
operation failures are caught.

4-6: Duplicate is_rocm() implementation.

As noted for tfc_tdf_v3.py, this function is duplicated. Consider extracting to a shared module like audio_separator/separator/uvr_lib_v5/utils.py.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@audio_separator/separator/uvr_lib_v5/stft.py` around lines 4 - 6, Duplicate
is_rocm() exists in stft.py and tfc_tdf_v3.py; extract a single implementation
into a new utility module (e.g., create
audio_separator/separator/uvr_lib_v5/utils.py with def is_rocm(): return "+rocm"
in torch.__version__), remove the duplicate definitions from stft.py and
tfc_tdf_v3.py, and update those modules to import is_rocm from the new utils
module (use from .utils import is_rocm) so both files reference the same shared
function.
debug_results.txt (1)

1-166: Consider excluding diagnostic output from version control.

This file appears to be output from a debug script run on a specific system. Committing environment-specific diagnostic logs may not be appropriate for the repository. Consider:

  1. Adding debug_results.txt to .gitignore
  2. Or renaming to debug_results_example.txt to indicate it's a reference

Also note: Line 162 shows MDX model operations failed with an empty exception message, which may warrant investigation.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@debug_results.txt` around lines 1 - 166, This file debug_results.txt is a
machine-specific diagnostic log and should not be committed; either add
debug_results.txt to .gitignore or rename it to debug_results_example.txt (to
keep a sanitized example), and remove the committed debug_results.txt from the
repo history/next commit. Also, investigate the MDX model failure reported as
"MDX model operations: FAILED" (search for your MDX loader/initialization and
inference entrypoints) to capture and log the actual exception so future
diagnostic runs show a concrete error instead of an empty exception message.
README.md (1)

122-124: Clarify that CUDAExecutionProvider is expected for ROCm.

The success message states ONNXruntime has CUDAExecutionProvider available for ROCm setups, which may confuse users expecting a ROCm-specific provider. Consider clarifying this is expected behavior:

Proposed improvement
 💬 If successfully configured, you should see this log message when running `audio-separator --env_info`:
- `ONNXruntime has CUDAExecutionProvider available, enabling acceleration`
+ `ONNXruntime has CUDAExecutionProvider available, enabling acceleration`
+
+**Note:** On ROCm systems, the CUDAExecutionProvider is used for AMD GPU acceleration through ROCm's HIP compatibility layer. This is expected behavior.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@README.md` around lines 122 - 124, Update the README text that shows the
success log so users on ROCm aren’t confused: change the explanation around the
log message "ONNXruntime has CUDAExecutionProvider available, enabling
acceleration" shown by `audio-separator --env_info` to note that ONNXRuntime may
report CUDAExecutionProvider even when using ROCm (this is expected behavior and
indicates acceleration via ROCm/HIP mapping), and add a brief sentence telling
ROCm users that no further action is needed if ROCm drivers are installed.
audio_separator/separator/roformer/roformer_loader.py (1)

226-232: Consider using weights_only=True consistently in the new implementation path.

The legacy implementation attempts weights_only=True with a fallback for older PyTorch versions, but the new implementation at line 97 doesn't use this safer option.

🛡️ Suggested fix for safer model loading
             if os.path.exists(model_path):
-                state_dict = torch.load(model_path, map_location=device)
+                try:
+                    state_dict = torch.load(model_path, map_location=device, weights_only=True)
+                except TypeError:
+                    # For older torch versions without weights_only
+                    state_dict = torch.load(model_path, map_location=device)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@audio_separator/separator/roformer/roformer_loader.py` around lines 226 -
232, The new loading path that calls torch.load(model_path, map_location=device)
should use the safer weights_only=True option like the legacy path; update the
torch.load call in the new implementation (where model_path and device are used)
to pass weights_only=True and wrap it in the same TypeError fallback to maintain
compatibility with older PyTorch versions (i.e., try torch.load(...,
map_location=device, weights_only=True) and except TypeError: torch.load(...,
map_location=device)).
debug_rocm_issues.py (2)

628-628: Remove extraneous f prefix.

-                print(f"✓ Matrix multiplication completed")
+                print("✓ Matrix multiplication completed")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@debug_rocm_issues.py` at line 628, The print statement currently uses an
unnecessary f-string: replace the call print(f"✓ Matrix multiplication
completed") with a plain string print("✓ Matrix multiplication completed") to
remove the extraneous `f` prefix (locate the print call in debug_rocm_issues.py
and update it).

473-476: Remove extraneous f prefix from strings without placeholders.

Lines 474 and 476 have f-strings with no interpolation.

♻️ Suggested fix
         # Test memory layout of result
         if result_gpu.is_contiguous():
-            print(f"  Result is contiguous")
+            print("  Result is contiguous")
         else:
-            print(f"  Result is not contiguous")
+            print("  Result is not contiguous")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@debug_rocm_issues.py` around lines 473 - 476, Remove the unnecessary f-string
prefixes on the print calls that check contiguity; update the two prints that
call result_gpu.is_contiguous() to use plain string literals (e.g., print("  Result is contiguous") and print("  Result is not contiguous")) instead of
print(f"...") so there is no unused interpolation prefix.
audio_separator/separator/separator.py (1)

99-127: Mutable default arguments could cause unexpected behavior.

Using mutable dictionaries as default arguments is a Python gotcha. While these specific dicts aren't mutated in the code, it's a best practice to use None and initialize inside the function.

♻️ Suggested fix pattern
     def __init__(
         self,
         ...
-        mdx_params={
-            "hop_length": 1024,
-            ...
-        },
+        mdx_params=None,
         ...
     ):
+        if mdx_params is None:
+            mdx_params = {
+                "hop_length": 1024,
+                ...
+            }
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@audio_separator/separator/separator.py` around lines 99 - 127, The default
parameter dicts mdx_params, vr_params, demucs_params, and mdxc_params are
mutable and should be changed to None in the function/class signature and then
initialized inside the function (e.g., if mdx_params is None: mdx_params =
{...}) to avoid shared-state bugs; update the function or constructor that
currently declares those defaults (referencing mdx_params, vr_params,
demucs_params, mdxc_params) to perform this lazy initialization so each
call/instance gets its own fresh dictionary.
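The gotcha flagged here is easy to demonstrate in a few lines. The hypothetical `make_params` below stands in for a constructor that declares a dict default like the ones named in this comment:

```python
def make_params(params={"hop_length": 1024}):  # BUG: the default dict is created once
    return params

first = make_params()
first["hop_length"] = 2048   # mutates the shared default dict
second = make_params()       # this call now sees 2048, not 1024
```

The fix pattern in the suggested diff (a default of `None` plus lazy initialization inside the function) gives each call its own fresh dictionary and avoids this shared state.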
audio_separator/separator/architectures/mdxc_separator.py (1)

550-553: Consider adding strict=True to zip() for safety.

When zipping instruments with inferenced outputs, using strict=True would catch length mismatches early rather than silently dropping data.

♻️ Suggested fix
             for key, value in zip(
                 self.model_data_cfgdict.training.instruments,
                 inferenced_outputs.cpu().detach().numpy(),
+                strict=True,
             ):
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@audio_separator/separator/architectures/mdxc_separator.py` around lines 550 -
553, The loop that pairs self.model_data_cfgdict.training.instruments with
inferenced_outputs.cpu().detach().numpy() should use zip(..., strict=True) to
surface any length mismatches immediately; update the for loop in mdxc_separator
(where the iteration over instruments and inferenced_outputs occurs) to call
zip(self.model_data_cfgdict.training.instruments,
inferenced_outputs.cpu().detach().numpy(), strict=True) and, if desired, wrap it
in a try/except to log a clear error message identifying MDXCSeparator and the
mismatched lengths before re-raising.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@audio_separator/separator/architectures/vr_separator.py`:
- Around line 461-470: The two raise ValueError calls (the ones using the
message "Window size error: h1_shape[3] must be greater than h2_shape[3]") are
using f-strings without any placeholders; remove the unnecessary f prefix so the
literals are normal strings (i.e., change raise ValueError(f"...") to raise
ValueError("...")) for both occurrences near the code that does pred =
pred.detach().cpu().numpy(), pred = np.concatenate(pred, axis=2), and
mask.append(pred).

In `@audio_separator/separator/separator.py`:
- Around line 954-956: Replace the unsafe yaml.load call that uses
Loader=yaml.FullLoader with yaml.safe_load to avoid instantiating arbitrary
Python objects; locate the call that assigns model_data (the model_data =
yaml.load(...) invocation in separator.py) and change it to use
yaml.safe_load(open(..., encoding="utf-8")) so the YAML parsing is performed
safely while preserving the same file reading and assignment to model_data.

In `@audio_separator/separator/uvr_lib_v5/demucs/hdemucs_rocm.py`:
- Line 284: The constructor currently uses a mutable default for dconv_kw
(dconv_kw={}); change the signature to dconv_kw=None and inside the __init__ of
the class (same pattern used for HEncLayer) set if dconv_kw is None: dconv_kw =
{} so the instance receives a fresh dict; update any references to dconv_kw in
the __init__ body accordingly to use the new local variable.
- Line 81: The __init__ signature uses a mutable default dconv_kw={}, causing
shared-state bugs; change the parameter default to None (dconv_kw=None) and at
the start of the __init__ method (in the same class/function where dconv_kw is
defined) add a guard that assigns a fresh dict when dconv_kw is None (e.g., if
dconv_kw is None: dconv_kw = {}), ensuring each instance gets its own
dictionary; update any references to dconv_kw accordingly.

In `@audio_separator/separator/uvr_lib_v5/roformer/bs_roformer.py`:
- Line 523: The current ROCm detection using x_is_rocm = True if "rocm" in
str(original_device).lower() else False is unreliable because ROCm devices
appear as "cuda" strings; replace it with a proper HIP check (e.g., use
torch.version.hip or getattr(torch.version, "hip", None)) to set x_is_rocm =
True when torch.version.hip is not None/truthy; ensure torch is imported and use
the same variable names (original_device, x_is_rocm) so the rest of the code
remains unchanged.

In `@audio_separator/separator/uvr_lib_v5/roformer/mel_band_roformer.py`:
- Around line 414-418: The variable x_is_rocm is computed but never used; update
the control flow around the "Only move to CPU for MPS" branch to either (A) use
x_is_rocm to prevent moving tensors to CPU when ROCm is detected (e.g., treat if
x_is_mps and not x_is_rocm then move to CPU), or (B) remove the unused x_is_rocm
assignment if no ROCm special-casing is required; apply the same change for the
analogous logic in bs_roformer (ensure references to x_is_mps and x_is_rocm in
both mel_band_roformer.py and bs_roformer.py are consistent).

In `@audio_separator/separator/uvr_lib_v5/tfc_tdf_v3.py`:
- Around line 19-24: The ROCm handling in __call__ is incorrect: replace the
hardcoded run_on_cpu = False with logic matching stft.py by detecting ROCm
(e.g., check torch.version.hip is not None or hasattr(torch.version, "hip"))
and/or non-cuda device; set run_on_cpu = True when ROCm is detected or
x_is_non_cuda_device is True so the code moves tensors to CPU for STFT/ISTFT on
ROCm. Update both occurrences (the one initializing run_on_cpu at the top of
__call__ and the second occurrence around lines 75-78) and ensure
x_is_non_cuda_device is actually used in the decision.
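One way to express the combined decision; the function and parameter names are hypothetical (the actual code uses a local `run_on_cpu` flag), and `hip_version` stands in for `torch.version.hip`:

```python
def should_run_stft_on_cpu(device, hip_version):
    # Fall back to CPU for STFT/ISTFT when the build targets ROCm
    # (hip_version is a string there, None on CUDA/CPU builds) or when
    # the tensor lives on a non-CUDA device such as "mps".
    is_rocm = hip_version is not None
    is_non_cuda_device = not str(device).startswith("cuda")
    return is_rocm or is_non_cuda_device
```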

In `@debug_rocm_issues.py`:
- Around line 405-406: The bare "except:" block in debug_rocm_issues.py (around
the handler containing "except: pass") is swallowing all exceptions; change it
to catch a specific exception (preferably OSError or at minimum Exception) by
replacing "except:" with "except OSError as e:" (or "except Exception as e:" if
appropriate), and handle it explicitly—e.g., log the error or comment why it's
safely ignored—rather than silently passing so issues aren't hidden.
- Around line 107-108: Move the unconditional module-level os.environ assignment
for "HSA_OVERRIDE_GFX_VERSION" into a small function (e.g.,
set_rocm_workarounds()) and call that function at the start of main() so
importing the module no longer produces side effects; specifically, remove the
top-level os.environ["HSA_OVERRIDE_GFX_VERSION"] = "10.3.0", implement
set_rocm_workarounds() to set that env var (and any other ROCm workarounds), and
invoke set_rocm_workarounds() at the beginning of main() to apply the workaround
only when running the diagnostic.
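A sketch of the suggested refactor; using `setdefault` here is a design choice beyond what the review asks for, so that a value the user already exported is never overwritten:

```python
import os


def set_rocm_workarounds():
    # setdefault respects values the user already configured, so the
    # diagnostic only fills in the gfx override when nothing is set.
    os.environ.setdefault("HSA_OVERRIDE_GFX_VERSION", "10.3.0")


def main():
    set_rocm_workarounds()  # applied only when the diagnostic actually runs
    # ... rest of the diagnostic ...
```

Importing the module now has no side effects; the override is applied only on the `main()` path.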

In `@Dockerfile.rocm`:
- Around line 21-22: Replace the pip install that currently specifies
"audio-separator[gpu]" and also installs "onnxruntime-rocm": change the RUN
command to install the ROCm-specific extra (use "audio-separator[rocm]") and
remove the redundant installation of onnxruntime packages that would pull CUDA
artifacts; if the [rocm] extra in pyproject.toml already includes
onnxruntime-rocm, drop the separate onnxruntime-rocm install so only the ROCm
variant is installed.

In `@README.md`:
- Around line 247-269: The README contains a duplicated "Multiple CUDA library
versions may be needed" section; remove the redundant copy so the documentation
appears only once. Locate the duplicate block by searching for the header text
"Multiple CUDA library versions may be needed" and remove the second occurrence
(keep the first intact), ensuring the following "ROCm specific troubleshooting"
section remains directly after the retained block and that no other content is
altered.

---

Nitpick comments:
In `@audio_separator/separator/architectures/mdxc_separator.py`:
- Around line 550-553: The loop that pairs
self.model_data_cfgdict.training.instruments with
inferenced_outputs.cpu().detach().numpy() should use zip(..., strict=True) to
surface any length mismatches immediately; update the for loop in mdxc_separator
(where the iteration over instruments and inferenced_outputs occurs) to call
zip(self.model_data_cfgdict.training.instruments,
inferenced_outputs.cpu().detach().numpy(), strict=True) and, if desired, wrap it
in a try/except to log a clear error message identifying MDXCSeparator and the
mismatched lengths before re-raising.

In `@audio_separator/separator/roformer/roformer_loader.py`:
- Around line 226-232: The new loading path that calls torch.load(model_path,
map_location=device) should use the safer weights_only=True option like the
legacy path; update the torch.load call in the new implementation (where
model_path and device are used) to pass weights_only=True and wrap it in the
same TypeError fallback to maintain compatibility with older PyTorch versions
(i.e., try torch.load(..., map_location=device, weights_only=True) and except
TypeError: torch.load(..., map_location=device)).
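The fallback pattern can be sketched generically; `load_fn` is a hypothetical parameter standing in for `torch.load` so the wrapper can be shown (and tested) without a checkpoint on disk:

```python
def load_checkpoint(load_fn, model_path, device):
    # Prefer weights_only=True (safer unpickling); older PyTorch versions
    # lack the keyword and raise TypeError, so fall back to the legacy call.
    try:
        return load_fn(model_path, map_location=device, weights_only=True)
    except TypeError:
        return load_fn(model_path, map_location=device)
```

In roformer_loader the call site would simply pass `torch.load` as `load_fn`, or inline the same try/except around the two `torch.load` calls.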

In `@audio_separator/separator/separator.py`:
- Around line 99-127: The default parameter dicts mdx_params, vr_params,
demucs_params, and mdxc_params are mutable and should be changed to None in the
function/class signature and then initialized inside the function (e.g., if
mdx_params is None: mdx_params = {...}) to avoid shared-state bugs; update the
function or constructor that currently declares those defaults (referencing
mdx_params, vr_params, demucs_params, mdxc_params) to perform this lazy
initialization so each call/instance gets its own fresh dictionary.
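A condensed sketch of the lazy-initialization pattern; the class name and the default values shown are hypothetical, with only the parameter names taken from the review text:

```python
class SketchSeparator:
    # None defaults plus per-instance construction: each instance gets
    # its own fresh dict, avoiding shared-state bugs across instances.
    def __init__(self, mdx_params=None, vr_params=None):
        self.mdx_params = {"segment_size": 256} if mdx_params is None else mdx_params
        self.vr_params = {"batch_size": 1} if vr_params is None else vr_params
```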

In `@audio_separator/separator/uvr_lib_v5/demucs/hdemucs_rocm.py`:
- Line 197: The for-loop in hdemucs_rocm.py uses an unused loop variable `k`
("for k in range(len(split_ratios) + 1):"); rename `k` to `_` to indicate the
variable is intentionally unused (i.e., change to "for _ in
range(len(split_ratios) + 1):") so linters won't flag an unused variable while
keeping the loop behavior the same; locate the loop referencing `split_ratios`
in the file and update the iterator variable accordingly.
- Line 634: Remove the redundant outer int() wrappers around math.ceil(...)
where lengths are computed (e.g., the assignment to le using
math.ceil(x.shape[-1] / hl) and the similar calculations at the other two
occurrences); math.ceil already returns an int in Python 3, so replace
expressions like le = int(math.ceil(...)) with le = math.ceil(...), and do the
same for the other two variables (the other math.ceil uses in the same scope).
- Line 51: Fix the typo in the inline comment that currently reads "when summing
gaussian, overscale raises as sqrt(n), so we nornalize by that." — change
"nornalize" to "normalize" in the comment within hdemucs_rocm.py (the comment
near the Gaussian summing/overscale logic) so it reads "... so we normalize by
that.".
- Line 214: The loop using zip(ratios, self.layers) can silently miss length
mismatches; change it to zip(ratios, self.layers, strict=True) so a ValueError
is raised if their lengths differ, updating the for-loop in the same location
(the line with for ratio, layer in zip(ratios, self.layers)) to include
strict=True to catch configuration errors early.

In `@audio_separator/separator/uvr_lib_v5/stft.py`:
- Around line 64-74: The except block currently catches a bare Exception around
the torch.stft call; change it to catch RuntimeError instead to avoid swallowing
unrelated errors. Specifically, in the code that computes stft_complex from
reshaped_tensor (using torch.stft with self.n_fft, self.hop_length and
stft_window) replace "except Exception as e" with "except RuntimeError as e" and
keep the fallback logic that constructs stft_output from stft_complex.real and
.imag so only PyTorch operation failures are caught.
- Around line 4-6: Duplicate is_rocm() exists in stft.py and tfc_tdf_v3.py;
extract a single implementation into a new utility module (e.g., create
audio_separator/separator/uvr_lib_v5/utils.py with def is_rocm(): return "+rocm"
in torch.__version__), remove the duplicate definitions from stft.py and
tfc_tdf_v3.py, and update those modules to import is_rocm from the new utils
module (use from .utils import is_rocm) so both files reference the same shared
function.
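The shared utility could look like the following; the optional `version` parameter is an addition beyond the review's zero-argument suggestion, included here only so the check is testable without a ROCm install:

```python
# Hypothetical audio_separator/separator/uvr_lib_v5/utils.py
def is_rocm(version=None):
    # ROCm wheels embed "+rocm" in torch.__version__ (e.g. "2.4.0+rocm6.1").
    if version is None:
        import torch  # deferred so importing utils stays cheap
        version = torch.__version__
    return "+rocm" in version
```

Both stft.py and tfc_tdf_v3.py would then use `from .utils import is_rocm` and drop their local copies.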

In `@audio_separator/separator/uvr_lib_v5/tfc_tdf_v3.py`:
- Around line 6-8: The is_rocm() function in tfc_tdf_v3.py duplicates the
implementation found in stft.py; refactor by moving the detection logic into a
single shared utility (e.g., a new or existing utils module) and replace both
local definitions with imports of that shared is_rocm function; update
references in tfc_tdf_v3.py and stft.py to import the centralized is_rocm and
remove the duplicate function to avoid future divergence.
- Around line 51-61: Replace the broad "except Exception as e:" around the
torch.stft fallback with a narrow, specific exception handler (e.g., "except
RuntimeError as e:") and ensure the caught error is logged before proceeding to
the fallback; reference the torch.stft call and the x_complex -> x stacking
logic and either log the exception via the module logger (including the error
message/traceback) or re-raise unexpected exception types instead of silencing
them so failures in torch.stft(return_complex=...) are visible.

In `@debug_results.txt`:
- Around line 1-166: This file debug_results.txt is a machine-specific
diagnostic log and should not be committed; either add debug_results.txt to
.gitignore or rename it to debug_results_example.txt (to keep a sanitized
example), and remove the committed debug_results.txt from the repo history/next
commit. Also, investigate the MDX model failure reported as "MDX model
operations: FAILED" (search for your MDX loader/initialization and inference
entrypoints) to capture and log the actual exception so future diagnostic runs
show a concrete error instead of an empty exception message.

In `@debug_rocm_issues.py`:
- Line 628: The print statement currently uses an unnecessary f-string: replace
the call print(f"✓ Matrix multiplication completed") with a plain string
print("✓ Matrix multiplication completed") to remove the extraneous `f` prefix
(locate the print call in debug_rocm_issues.py and update it).
- Around line 473-476: Remove the unnecessary f-string prefixes on the print
calls that check contiguity; update the two prints that call
result_gpu.is_contiguous() to use plain string literals (e.g., print("  Result
is contiguous") and print("  Result is not contiguous")) instead of
print(f"...") so there is no unused interpolation prefix.

In `@Dockerfile.rocm`:
- Around line 8-12: The apt install line in the Dockerfile RUN step installs
recommended packages increasing image size; update the RUN command that includes
apt-get install -y ffmpeg python3 python3-pip (the install invocation) to add
the --no-install-recommends flag to apt-get install and keep the existing
cleanup (rm -rf /var/lib/apt/lists/*) so only essential packages are installed
and image size is reduced.
- Around line 24-26: The Dockerfile hardcodes AMD GPU targets via ENV
HSA_OVERRIDE_GFX_VERSION=10.3.2 and ENV PYTORCH_ROCM_ARCH=gfx1030 which reduces
portability; change these to configurable build arguments (e.g., ARG
HSA_OVERRIDE_GFX_VERSION and ARG PYTORCH_ROCM_ARCH) and set the ENV values from
those args, and add a comment above explaining they should be overridden for
different GPUs (RX7000, MI series, etc.) so users can pass --build-arg to
customize the GPU target at build time.

In `@README.md`:
- Around line 122-124: Update the README text that shows the success log so
users on ROCm aren’t confused: change the explanation around the log message
"ONNXruntime has CUDAExecutionProvider available, enabling acceleration" shown
by `audio-separator --env_info` to note that ONNXRuntime may report
CUDAExecutionProvider even when using ROCm (this is expected behavior and
indicates acceleration via ROCm/HIP mapping), and add a brief sentence telling
ROCm users that no further action is needed if ROCm drivers are installed.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: 3cc518b4-dbc9-40ad-ba2b-906762aa3ceb

📥 Commits

Reviewing files that changed from the base of the PR and between 3f58f95 and b18b0f2.

📒 Files selected for processing (16)
  • Dockerfile.rocm
  • README.md
  • audio_separator/separator/architectures/demucs_separator.py
  • audio_separator/separator/architectures/mdxc_separator.py
  • audio_separator/separator/architectures/vr_separator.py
  • audio_separator/separator/common_separator.py
  • audio_separator/separator/roformer/roformer_loader.py
  • audio_separator/separator/separator.py
  • audio_separator/separator/uvr_lib_v5/demucs/hdemucs_rocm.py
  • audio_separator/separator/uvr_lib_v5/roformer/bs_roformer.py
  • audio_separator/separator/uvr_lib_v5/roformer/mel_band_roformer.py
  • audio_separator/separator/uvr_lib_v5/stft.py
  • audio_separator/separator/uvr_lib_v5/tfc_tdf_v3.py
  • debug_results.txt
  • debug_rocm_issues.py
  • pyproject.toml


@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🧹 Nitpick comments (10)
debug_rocm_issues.py (2)

474-477: Remove extraneous f prefix from strings without placeholders.

These f-strings contain no placeholders and should be regular strings per Ruff F541.

♻️ Proposed fix
         # Test memory layout of result
         if result_gpu.is_contiguous():
-            print(f"  Result is contiguous")
+            print("  Result is contiguous")
         else:
-            print(f"  Result is not contiguous")
+            print("  Result is not contiguous")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@debug_rocm_issues.py` around lines 474 - 477, The two print statements inside
the result_gpu.is_contiguous() conditional use f-strings with no placeholders;
change them to plain strings by removing the leading 'f' so the prints become
standard strings (locate the prints in the result_gpu.is_contiguous() branch in
debug_rocm_issues.py).

236-237: Diagnostic script only runs test_model_types() despite defining multiple test functions.

The main() function prints the setup section headers but only explicitly invokes test_model_types(); the other diagnostics are never executed. Consider running all test functions or documenting why the others are skipped.

♻️ Proposed fix to run all tests
     # Check memory layout
     print_section("0.4 Memory Layout Check")
     # ... existing code ...

-    # Run only MDX model test for focused debugging
-    test_model_types()
+    # Run all diagnostic tests
+    test_pytorch_rocm_setup()
+    test_onnxruntime_setup()
+    test_stft_operations()
+    test_memory_allocation()
+    test_model_types()

     print("\n" + "=" * 50)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@debug_rocm_issues.py` around lines 236 - 237, The main() currently only
invokes test_model_types(), so other diagnostics like test_pytorch_rocm_setup()
and test_onnxruntime() never run; update main() to explicitly call all test
functions (e.g., test_pytorch_rocm_setup(), test_onnxruntime(),
test_model_types()) in the intended order or add clear documentation/logging in
main() explaining why certain tests are intentionally skipped; locate the calls
in main() and add the missing function invocations (or comments) so the
diagnostic script executes every intended test.
audio_separator/separator/uvr_lib_v5/demucs/hdemucs_rocm.py (3)

221-221: Remove unnecessary int() casts on round() results.

In Python 3, round() returns an int when called with a single argument, making the outer int() call redundant.

♻️ Proposed fix
                 if ratio == 1:
                     limit = Fr
                     frames = -1
                 else:
-                    limit = int(round(Fr * ratio))
+                    limit = round(Fr * ratio)
                     le = limit - start

Also applies to: 242-242

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@audio_separator/separator/uvr_lib_v5/demucs/hdemucs_rocm.py` at line 221,
Remove the redundant int() casts around round() results in hdemucs_rocm.py:
replace occurrences like "limit = int(round(Fr * ratio))" (and the similar
occurrence at the other location) with just "limit = round(Fr * ratio)" so that
the value uses the int returned by round() in Python 3; update any nearby
assignments that wrap round(...) with int(...) accordingly (search for
"int(round(" to find all instances in this module, including the one at the
second mentioned location).

209-209: Prefix unused unpacked variables with underscore.

Variables B, C, T are unpacked but never used in the forward method. Prefix with _ to indicate intentional non-use.

♻️ Proposed fix
     def forward(self, x, skip=None, length=None):
-        B, C, Fr, T = x.shape
+        _B, _C, Fr, _T = x.shape

         ratios = list(self.split_ratios) + [1]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@audio_separator/separator/uvr_lib_v5/demucs/hdemucs_rocm.py` at line 209, In
the forward method where the input shape is unpacked (currently "B, C, Fr, T =
x.shape"), prefix the unused variables B, C, and T with an underscore (e.g., _B,
_C, _T) so they are marked as intentionally unused; update the unpacking near
the start of hdemucs_rocm.forward (the line unpacking x.shape) to use _B, _C,
Fr, _T and ensure downstream code still references only Fr.

33-34: Assertions are silently stripped in production if optimizations are enabled.

The assertions at lines 33-34 verify padding correctness but are removed when Python runs with the -O flag. Use explicit checks that raise exceptions for critical invariants if this code may run in optimized mode.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@audio_separator/separator/uvr_lib_v5/demucs/hdemucs_rocm.py` around lines 33
- 34, Replace the two assert statements that check padding correctness (the
lines using out.shape and comparing out[..., padding_left : padding_left +
length] == x0) with explicit runtime checks that raise a clear exception (e.g.,
ValueError or RuntimeError) when the invariants fail; use the same symbols (out,
padding_left, padding_right, length, x0) and perform a strict shape check for
out.shape[-1] == length + padding_left + padding_right and an element-wise
equality check (using the appropriate array/tensor equality helper in this
module) and include the actual values in the error message to aid debugging.
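The shape half of the suggested replacement can be sketched as a small helper; the function name is hypothetical, and only the symbol names (`padding_left`, `padding_right`, `length`) mirror the review text. The element-wise tensor comparison would be handled analogously with the module's tensor-equality helper:

```python
def check_padded_length(actual_last_dim, padding_left, padding_right, length):
    # Unlike assert, this check survives `python -O` and reports the
    # offending values in the error message.
    expected = length + padding_left + padding_right
    if actual_last_dim != expected:
        raise RuntimeError(
            f"padded length {actual_last_dim} != "
            f"{length} + {padding_left} + {padding_right} = {expected}"
        )
```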
audio_separator/separator/uvr_lib_v5/stft.py (2)

52-71: Catch a more specific exception type.

The bare Exception catch is too broad per Ruff BLE001. Since this fallback handles torch.stft compatibility issues, catching RuntimeError or TypeError would be more appropriate.

♻️ Proposed fix
         # Perform the Short-Time Fourier Transform (STFT) on the reshaped tensor.
         try:
             stft_output = torch.stft(
                 reshaped_tensor,
                 n_fft=self.n_fft,
                 hop_length=self.hop_length,
                 window=stft_window,
                 center=True,
                 return_complex=False,
             )
-        except Exception as e:
+        except RuntimeError:
             # Fallback: try with return_complex=True
             stft_complex = torch.stft(
                 reshaped_tensor,
                 n_fft=self.n_fft,
                 hop_length=self.hop_length,
                 window=stft_window,
                 center=True,
                 return_complex=True,
             )
             stft_output = torch.stack([stft_complex.real, stft_complex.imag], dim=-1)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@audio_separator/separator/uvr_lib_v5/stft.py` around lines 52 - 71, The
current broad except Exception around the torch.stft call should be narrowed to
only the compatibility-related errors; change the handler to catch RuntimeError
and TypeError instead of Exception so the fallback path (which computes
stft_complex and sets stft_output from stft_complex.real/imag for
reshaped_tensor) only runs for those specific torch.stft compatibility failures
and doesn't swallow unrelated exceptions.

25-27: Unused variable original_dtype in forward path.

The original_dtype is stored but never used. If dtype restoration is intended, implement it; otherwise, remove the unused variable.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@audio_separator/separator/uvr_lib_v5/stft.py` around lines 25 - 27, The
variable original_dtype is captured but never used in the forward path; either
remove its assignment or use it to restore tensor dtypes after processing:
locate the forward function where original_dtype is set (search for
original_dtype and input_tensor) and if dtype preservation is intended, cast the
final output(s) back to original_dtype (e.g., convert result tensors to
original_dtype before returning), otherwise delete the original_dtype assignment
to avoid the unused variable. Ensure any restoration happens after all CUDA/CPU
device logic (around run_on_cpu and subsequent transforms) so device and dtype
are correctly reconciled.
audio_separator/separator/uvr_lib_v5/tfc_tdf_v3.py (2)

25-27: Remove unused original_dtype variables.

original_dtype is stored in both __call__ (line 25) and inverse (line 67) but never used. Remove these or implement dtype restoration if intended.

♻️ Proposed fix
         if run_on_cpu:
             x = x.cpu()

         # Ensure FP32 for stability on ROCm and non-standard devices
-        original_dtype = x.dtype
         if x.dtype in (torch.float16, torch.bfloat16):
             x = x.float()

Also applies to: 67-69

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@audio_separator/separator/uvr_lib_v5/tfc_tdf_v3.py` around lines 25 - 27,
Remove the unused variable original_dtype in both the __call__ and inverse
methods of tfc_tdf_v3.py: delete the lines that assign original_dtype = x.dtype
(and any corresponding unused references) since dtype restoration is not
implemented; if you intended to preserve dtype instead, replace those lines by
storing dtype and then casting the output back to that dtype (e.g., out =
out.to(original_dtype)) at the end of the method. Ensure the change touches the
original_dtype assignments in __call__ and inverse only.

5-6: Remove unused is_rocm import.

The is_rocm function is imported but never used in this file. Either remove the unused import or integrate ROCm-aware logic as done in stft.py.

♻️ Proposed fix
 import torch
 import torch.nn as nn
 from functools import partial
 
-from audio_separator.separator.uvr_lib_v5.utils import is_rocm
-
-
+
 class STFT:
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@audio_separator/separator/uvr_lib_v5/tfc_tdf_v3.py` around lines 5 - 6,
Remove the unused import is_rocm from
audio_separator.separator.uvr_lib_v5.tfc_tdf_v3.py or add ROCm-aware branching
similar to stft.py; specifically either delete the line importing is_rocm or use
is_rocm() inside the relevant functions (e.g., any tensor/device setup in
tfc_tdf_v3 functions) to select the correct backend/device code path as done in
stft.py so the import is actually used.
Dockerfile.rocm (1)

1-34: Consider adding a non-root user for security.

The container runs as root by default, which Trivy flagged as a security concern (DS-0002). For production use, consider adding a non-root user.

♻️ Proposed fix
 # Install audio-separator with GPU support
 RUN --mount=type=cache,target=/root/.cache \
     pip3 install "audio-separator[gpu]" onnxruntime-rocm

+# Create non-root user for security
+RUN useradd -m -s /bin/bash separator
+USER separator
+WORKDIR /home/separator/workdir
+
 # Default environment variables for AMD RX 6600 series (gfx1032)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@Dockerfile.rocm` around lines 1 - 34, Add a non-root user and switch to it
before CMD/ENTRYPOINT: create a group and user (optionally via build args like
UID/GID), chown the WORKDIR (/workdir) and any cache or pip directories to that
user, set HOME and USER environment variables, and add a USER instruction so the
container no longer runs as root; ensure this change is applied before the
ENTRYPOINT ["audio-separator"] so the process starts as the non-root account.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@audio_separator/separator/uvr_lib_v5/stft.py`:
- Line 3: The import in stft.py that reads "from
audio_separator.separator.uvr_lib_v5.utils import is_rocm" fails because
utils.py doesn't exist and is_rocm is a boolean attribute elsewhere; remove that
import and instead reference the existing boolean is_rocm from the module that
defines it (e.g., import the flag from separator.py or common_separator.py where
is_rocm is declared) or pass the boolean into the functions that need it; apply
the same fix in tfc_tdf_v3.py (replace the non-existent utils import with an
import of the is_rocm boolean from the correct module or adjust call sites to
accept the flag).

---

Nitpick comments:
In `@audio_separator/separator/uvr_lib_v5/demucs/hdemucs_rocm.py`:
- Line 221: Remove the redundant int() casts around round() results in
hdemucs_rocm.py: replace occurrences like "limit = int(round(Fr * ratio))" (and
the similar occurrence at the other location) with just "limit = round(Fr *
ratio)" so that the value uses the int returned by round() in Python 3; update
any nearby assignments that wrap round(...) with int(...) accordingly (search
for "int(round(" to find all instances in this module, including the one at the
second mentioned location).
- Line 209: In the forward method where the input shape is unpacked (currently
"B, C, Fr, T = x.shape"), prefix the unused variables B, C, and T with an
underscore (e.g., _B, _C, _T) so they are marked as intentionally unused; update
the unpacking near the start of hdemucs_rocm.forward (the line unpacking
x.shape) to use _B, _C, Fr, _T and ensure downstream code still references only
Fr.
- Around line 33-34: Replace the two assert statements that check padding
correctness (the lines using out.shape and comparing out[..., padding_left :
padding_left + length] == x0) with explicit runtime checks that raise a clear
exception (e.g., ValueError or RuntimeError) when the invariants fail; use the
same symbols (out, padding_left, padding_right, length, x0) and perform a strict
shape check for out.shape[-1] == length + padding_left + padding_right and an
element-wise equality check (using the appropriate array/tensor equality helper
in this module) and include the actual values in the error message to aid
debugging.

In `@audio_separator/separator/uvr_lib_v5/stft.py`:
- Around line 52-71: The current broad except Exception around the torch.stft
call should be narrowed to only the compatibility-related errors; change the
handler to catch RuntimeError and TypeError instead of Exception so the fallback
path (which computes stft_complex and sets stft_output from
stft_complex.real/imag for reshaped_tensor) only runs for those specific
torch.stft compatibility failures and doesn't swallow unrelated exceptions.
- Around line 25-27: The variable original_dtype is captured but never used in
the forward path; either remove its assignment or use it to restore tensor
dtypes after processing: locate the forward function where original_dtype is set
(search for original_dtype and input_tensor) and if dtype preservation is
intended, cast the final output(s) back to original_dtype (e.g., convert result
tensors to original_dtype before returning), otherwise delete the original_dtype
assignment to avoid the unused variable. Ensure any restoration happens after
all CUDA/CPU device logic (around run_on_cpu and subsequent transforms) so
device and dtype are correctly reconciled.

In `@audio_separator/separator/uvr_lib_v5/tfc_tdf_v3.py`:
- Around line 25-27: Remove the unused variable original_dtype in both the
__call__ and inverse methods of tfc_tdf_v3.py: delete the lines that assign
original_dtype = x.dtype (and any corresponding unused references) since dtype
restoration is not implemented; if you intended to preserve dtype instead,
replace those lines by storing dtype and then casting the output back to that
dtype (e.g., out = out.to(original_dtype)) at the end of the method. Ensure the
change touches the original_dtype assignments in __call__ and inverse only.
- Around line 5-6: Remove the unused import is_rocm from
audio_separator.separator.uvr_lib_v5.tfc_tdf_v3.py or add ROCm-aware branching
similar to stft.py; specifically either delete the line importing is_rocm or use
is_rocm() inside the relevant functions (e.g., any tensor/device setup in
tfc_tdf_v3 functions) to select the correct backend/device code path as done in
stft.py so the import is actually used.

In `@debug_rocm_issues.py`:
- Around line 474-477: The two print statements inside the
result_gpu.is_contiguous() conditional use f-strings with no placeholders;
change them to plain strings by removing the leading 'f' so the prints become
standard strings (locate the prints in the result_gpu.is_contiguous() branch in
debug_rocm_issues.py).
- Around line 236-237: The main() currently only invokes test_model_types(), so
other diagnostics like test_pytorch_rocm_setup() and test_onnxruntime() never
run; update main() to explicitly call all test functions (e.g.,
test_pytorch_rocm_setup(), test_onnxruntime(), test_model_types()) in the
intended order or add clear documentation/logging in main() explaining why
certain tests are intentionally skipped; locate the calls in main() and add the
missing function invocations (or comments) so the diagnostic script executes
every intended test.

In `@Dockerfile.rocm`:
- Around line 1-34: Add a non-root user and switch to it before CMD/ENTRYPOINT:
create a group and user (optionally via build args like UID/GID), chown the
WORKDIR (/workdir) and any cache or pip directories to that user, set HOME and
USER environment variables, and add a USER instruction so the container no
longer runs as root; ensure this change is applied before the ENTRYPOINT
["audio-separator"] so the process starts as the non-root account.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: cf805857-3eb8-4d80-924c-a654220a11a2

📥 Commits

Reviewing files that changed from the base of the PR and between b18b0f2 and d51648f.

📒 Files selected for processing (7)
  • .gitignore
  • Dockerfile.rocm
  • audio_separator/separator/uvr_lib_v5/demucs/hdemucs_rocm.py
  • audio_separator/separator/uvr_lib_v5/stft.py
  • audio_separator/separator/uvr_lib_v5/tfc_tdf_v3.py
  • debug_results.txt
  • debug_rocm_issues.py
✅ Files skipped from review due to trivial changes (1)
  • .gitignore

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 3

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@README.md`:
- Line 131: Replace the incorrect GPU extras group in the README install command
so ROCm users get the ROCm-specific runtime; specifically, change the install
instruction that uses the "gpu" extras group to use "rocm" instead (the README
line currently showing pip install "audio-separator[gpu]" should reference the
"rocm" extras group which installs onnxruntime-rocm rather than
onnxruntime-gpu). Update any other occurrences (e.g., the second occurrence
around line 139) to use "rocm" to match pyproject.toml's extras definitions.
- Line 144: Update the README entry that lists ONNX Runtime packages so ROCm
users are not pointed to the CUDA package: replace the incorrect
`onnxruntime-gpu` mention for ROCm with `onnxruntime-rocm` (the same package
named in the pyproject.toml rocm extras) so the ONNX Runtime line reads the
correct set including `onnxruntime-rocm` rather than `onnxruntime-gpu`.
- Line 128: The README contains inconsistent PyTorch ROCm install URLs: replace
the nightly rocm7.2 and the stable rocm5.7 occurrences with the consistent
stable ROCm 7.2 URL; specifically update the pip install lines that currently
reference "https://download.pytorch.org/whl/nightly/rocm7.2/" and
"https://download.pytorch.org/whl/rocm5.7" to use
"https://download.pytorch.org/whl/rocm7.2" so docs match the Dockerfile.rocm's
tested configuration.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: 0aa3fcb5-933a-436e-8445-53e8e8be1d10

📥 Commits

Reviewing files that changed from the base of the PR and between d51648f and c184b76.

📒 Files selected for processing (1)
  • README.md

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 5

♻️ Duplicate comments (2)
debug_rocm_issues.py (2)

402-407: ⚠️ Potential issue | 🟡 Minor

Don’t swallow install-path scan failures.

This bare except: pass hides permission and I/O errors in the middle of the ROCm installation diagnostic. Catch OSError and print a short note instead so users can tell whether the path was absent or just unreadable.

♻️ Proposed fix
-                    except:
-                        pass
+                    except OSError as exc:
+                        print(f"  Could not inspect {path}: {exc}")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@debug_rocm_issues.py` around lines 402 - 407, Replace the bare except in the
os.listdir(path) scan with an OSError handler so you don't swallow
permission/I/O errors: in the try/except around os.listdir(path) (the block that
checks "lib" or "bin" and prints "ROCm components detected in {path}"), catch
OSError and print a short diagnostic like "Unable to read {path}: {error}" (or
similar) instead of silently passing; keep other exceptions uncaught so they
surface.
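The handler described above can be sketched as a small helper; the function name, message format, and the `lib`/`bin` heuristic are illustrative, not the exact code in `debug_rocm_issues.py`:

```python
import os

def scan_install_path(path):
    """List a candidate ROCm install directory, reporting rather than hiding errors."""
    try:
        entries = os.listdir(path)
    except OSError as exc:  # absent path, permission denied, I/O error
        print(f"  Could not inspect {path}: {exc}")
        return []
    if "lib" in entries or "bin" in entries:
        print(f"  ROCm components detected in {path}")
    return entries
```

Because `FileNotFoundError` and `PermissionError` are both subclasses of `OSError`, the single handler covers the absent-path and unreadable-path cases while letting unrelated exceptions surface.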

107-109: ⚠️ Potential issue | 🟠 Major

Don’t hard-code ROCm env overrides at import time.

These assignments mutate the process as soon as the module is imported, so the report no longer reflects the user’s real environment. They also happen after import torch, which can be too late for ROCm/PyTorch initialization, and PYTORCH_ROCM_ARCH="gfx1030" will be wrong for non-gfx1030 cards.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@debug_rocm_issues.py` around lines 107 - 109, The module currently hard-sets
HSA_OVERRIDE_GFX_VERSION and PYTORCH_ROCM_ARCH at import time (after import
torch), which mutates the process and can be too late or incorrect; remove these
top‑level os.environ assignments and instead provide a function (e.g.,
set_rocm_env or configure_rocm_env) that (1) checks for existing environment
variables and only sets defaults if they are unset, (2) detects the GPU/ROCm
version dynamically (via rocminfo or a safe detection routine) to choose the
correct PYTORCH_ROCM_ARCH value instead of forcing "gfx1030", and (3) must be
called before importing torch when the caller explicitly wants to override the
environment. Ensure the code references HSA_OVERRIDE_GFX_VERSION and
PYTORCH_ROCM_ARCH only within that function and do not perform mutations at
module import time.
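A minimal sketch of the opt-in function suggested above, using `os.environ.setdefault` so existing user settings win; the function name and the example values passed in are placeholders, since the correct `gfx` target depends on the actual card:

```python
import os

def configure_rocm_env(gfx_version=None, rocm_arch=None):
    """Apply ROCm env defaults without clobbering the user's real environment.

    Must be called before `import torch` to take effect. The caller supplies
    values (ideally detected via rocminfo) rather than this module hard-coding
    them at import time.
    """
    if gfx_version is not None:
        os.environ.setdefault("HSA_OVERRIDE_GFX_VERSION", gfx_version)
    if rocm_arch is not None:
        os.environ.setdefault("PYTORCH_ROCM_ARCH", rocm_arch)
```

`setdefault` only writes a variable that is currently unset, so a user who exported their own override keeps it.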
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@audio_separator/separator/uvr_lib_v5/utils.py`:
- Around line 4-6: The is_rocm() helper currently checks "+rocm" in
torch.__version__, which is unreliable; update is_rocm() to return
getattr(torch.version, "hip", None) is not None so it uses the supported runtime
indicator for ROCm/HIP builds (this affects decisions in stft.py that move STFT
processing to CPU versus GPU). Locate the is_rocm function in utils.py and
replace the version-string check with the getattr(torch.version, "hip", None) is
not None pattern so ROCm builds are detected robustly.
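The recommended check can be exercised without a ROCm build by passing stand-in namespace objects in place of the real `torch.version` module (the version strings below are made up for illustration):

```python
from types import SimpleNamespace

def is_rocm_build(version_module) -> bool:
    # ROCm/HIP builds of PyTorch set torch.version.hip to a version string;
    # CUDA and CPU builds leave it as None, so this avoids fragile
    # substring checks on torch.__version__.
    return getattr(version_module, "hip", None) is not None

# Stand-ins emulating torch.version for different build flavours:
rocm_build = SimpleNamespace(hip="6.2.41133", cuda=None)
cuda_build = SimpleNamespace(hip=None, cuda="12.1")
```

In the real helper the argument would simply be `torch.version`, i.e. `return getattr(torch.version, "hip", None) is not None`.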

In `@debug_rocm_issues.py`:
- Around line 684-687: The loop `for t in tensors: del t` doesn’t remove the
references in the `tensors` list, so clear those references before calling
`torch.cuda.empty_cache()`; replace that loop with a proper clear such as
`tensors.clear()` or `del tensors[:]` (or set each element to None) and then
call `gc.collect()` (optional) followed by `torch.cuda.empty_cache()` so GPU
memory is actually freed; ensure you update code around the `tensors` variable
and the existing `torch.cuda.empty_cache()` call.
- Around line 456-478: The contiguity check uses result_gpu which is only
assigned inside the per-n_fft try block, so if all torch.stft calls fail an
UnboundLocalError occurs; initialize a sentinel (e.g., result_gpu = None) before
the for-loop and after the loop guard the is_contiguous() check by first
verifying result_gpu is not None (and optionally log that all STFT attempts
failed) so the subsequent diagnostics continue; update references around the for
n_fft loop and the final contiguity test to use this sentinel logic.
- Around line 24-42: The subprocess.run calls that invoke external tools (the
rocminfo call that writes to rocm_info["rocminfo"] and the hipconfig
subprocess.run that writes the hipconfig output) must include check=True so
non-zero exits raise CalledProcessError and get handled by your existing except
block; update both subprocess.run invocations to add check=True while keeping
capture_output=True, text=True, timeout=10, and ensure the except still catches
subprocess.CalledProcessError/FileNotFoundError/TimeoutExpired to set rocm_info
entries appropriately.
- Around line 233-240: The session creation is using a non-existent
"dummy_model.onnx" before the in-memory test model is built, causing false
failures; move the minimal MatMul model construction so it runs before creating
any ort.InferenceSession, then create sessions for both ROCMExecutionProvider
and CPUExecutionProvider (and the CUDA branch) by passing
model.SerializeToString() (not a filename) and the previously configured
session_options; ensure you update the code paths around session_options and
ort.InferenceSession to use the serialized in-memory model in both branches.
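The `check=True` pattern from the subprocess comment above can be factored into one helper so `rocminfo` and `hipconfig` share the same error handling; the helper name and the "unavailable:" note are illustrative:

```python
import subprocess

def run_tool(cmd, timeout=10):
    """Capture a diagnostic tool's stdout, or a short note explaining why not.

    check=True turns a non-zero exit into CalledProcessError, so failed runs
    land in the except block instead of silently yielding empty output.
    """
    try:
        result = subprocess.run(cmd, capture_output=True, text=True,
                                timeout=timeout, check=True)
        return result.stdout
    except (subprocess.CalledProcessError, FileNotFoundError,
            subprocess.TimeoutExpired) as exc:
        return f"unavailable: {exc}"

rocm_info = {"rocminfo": run_tool(["rocminfo"]),
             "hipconfig": run_tool(["hipconfig", "--full"])}
```

On a machine without ROCm installed, both entries come back as "unavailable: …" notes rather than raising, which is what a diagnostic script wants.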

---

Duplicate comments:
In `@debug_rocm_issues.py`:
- Around line 402-407: Replace the bare except in the os.listdir(path) scan with
an OSError handler so you don't swallow permission/I/O errors: in the try/except
around os.listdir(path) (the block that checks "lib" or "bin" and prints "ROCm
components detected in {path}"), catch OSError and print a short diagnostic like
"Unable to read {path}: {error}" (or similar) instead of silently passing; keep
other exceptions uncaught so they surface.
- Around line 107-109: The module currently hard-sets HSA_OVERRIDE_GFX_VERSION
and PYTORCH_ROCM_ARCH at import time (after import torch), which mutates the
process and can be too late or incorrect; remove these top‑level os.environ
assignments and instead provide a function (e.g., set_rocm_env or
configure_rocm_env) that (1) checks for existing environment variables and only
sets defaults if they are unset, (2) detects the GPU/ROCm version dynamically
(via rocminfo or a safe detection routine) to choose the correct
PYTORCH_ROCM_ARCH value instead of forcing "gfx1030", and (3) must be called
before importing torch when the caller explicitly wants to override the
environment. Ensure the code references HSA_OVERRIDE_GFX_VERSION and
PYTORCH_ROCM_ARCH only within that function and do not perform mutations at
module import time.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: 37bc4423-d72c-4999-bb6e-7c537da264fd

📥 Commits

Reviewing files that changed from the base of the PR and between c184b76 and 1dadc0b.

📒 Files selected for processing (2)
  • audio_separator/separator/uvr_lib_v5/utils.py
  • debug_rocm_issues.py


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
audio_separator/separator/separator.py (1)

108-153: ⚠️ Potential issue | 🔴 Critical

Restore the removed ensemble kwargs and avoid shared dict defaults.

The constructor body still reads ensemble_algorithm, ensemble_weights, and ensemble_preset at Lines 269-282, but this signature no longer accepts them, so Separator() now raises NameError during initialization. While fixing that, switch the arch-config defaults to None; these dict literals are currently shared across instances.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@audio_separator/separator/separator.py` around lines 108 - 153, The __init__
signature was changed but the constructor still reads ensemble_algorithm,
ensemble_weights, and ensemble_preset, causing a NameError; add these three
parameters back to the Separator.__init__ signature (with sensible defaults,
e.g., None or existing presets) so the body can reference them. Also avoid using
mutable dict literals as default args for mdx_params, vr_params, demucs_params,
and mdxc_params: change their defaults to None in the __init__ signature and
inside the constructor set them to the original dict literals if the
corresponding argument is None (e.g., if mdx_params is None: mdx_params =
{...}), ensuring each instance gets its own copy. Ensure references to
ensemble_algorithm, ensemble_weights, and ensemble_preset in the constructor
remain unchanged.
pyproject.toml (1)

49-66: ⚠️ Potential issue | 🟠 Major

Regenerate and commit poetry.lock with this dependency change.

CI is already failing poetry install -E gpu because pyproject.toml changed without a matching lockfile refresh. Until the updated lockfile is included, clean installs from this PR stay broken.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@pyproject.toml` around lines 49 - 66, The pyproject.toml extras and
dependency entries (e.g., onnxruntime-rocm, [tool.poetry.extras]
gpu/onnxruntime-gpu) were changed but the poetry.lock was not updated; regenerate
the lockfile and commit it so CI installs with the new extras succeed: run
Poetry to refresh the lock (e.g., run poetry lock or poetry lock --no-update as
appropriate), verify a clean install with the gpu extra (poetry install -E gpu),
and commit the updated poetry.lock alongside the pyproject.toml change.
♻️ Duplicate comments (3)
audio_separator/separator/separator.py (1)

1027-1029: ⚠️ Potential issue | 🟠 Major

Still parsing downloaded YAML with yaml.FullLoader.

download_model_files() can fetch model configs from remote repos, and this method still parses them with yaml.load(..., Loader=yaml.FullLoader). Please switch to yaml.safe_load and a context manager before shipping more remote model/config support.

Run the following script to confirm the remote-download path and the loader currently in use:

#!/bin/bash
sed -n '950,1035p' audio_separator/separator/separator.py
rg -n "yaml\.load|yaml\.safe_load|FullLoader" audio_separator/separator/separator.py
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@audio_separator/separator/separator.py` around lines 1027 - 1029, The code at
the yaml parsing site uses yaml.load with yaml.FullLoader when reading
model_data_yaml_filepath (used to produce model_data), which is unsafe for
remote configs; change the logic that reads the file (the block that assigns
model_data) to open the file with a context manager (with
open(model_data_yaml_filepath, encoding="utf-8") as f:) and parse via
yaml.safe_load(f) instead of yaml.load(..., Loader=yaml.FullLoader); ensure any
other occurrences in the same module (and in the download_model_files() code
path) are updated to safe_load to avoid insecure parsing.
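The replacement loader call looks like this; the function name is illustrative, but the `safe_load` + context-manager shape is the one the comment asks for:

```python
import yaml

def load_model_config(path):
    """Parse a downloaded model config.

    safe_load constructs only plain Python types (dict, list, str, numbers),
    so a malicious remote YAML file cannot trigger arbitrary object
    construction the way yaml.load with FullLoader-era loaders could.
    """
    with open(path, encoding="utf-8") as f:
        return yaml.safe_load(f)
```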
README.md (2)

207-219: ⚠️ Potential issue | 🟠 Major

Split the recovery commands by backend or make this section CUDA-only again.

After Line 209 distinguishes [gpu] vs [rocm], the example immediately below unconditionally reinstalls onnxruntime-gpu. A ROCm user following these steps will replace the AMD runtime with the CUDA package.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@README.md` around lines 207 - 219, The recovery commands currently
unconditionally reinstall onnxruntime-gpu which will break ROCm users; update
the README section that follows the `[gpu]` vs `[rocm]` note to either split the
example into two backend-specific blocks or revert this paragraph to be
CUDA-only: keep the generic uninstall/cache steps (`pip uninstall torch
onnxruntime` and `pip cache purge`) but show separate reinstall commands — `pip
install --force-reinstall torch torchvision torchaudio` + `pip install
--force-reinstall onnxruntime-gpu` for CUDA/NVIDIA (`[gpu]`) and the appropriate
ROCm runtime install for AMD (`[rocm]`) — and clearly label each block so
readers for `onnxruntime-gpu` are only those using `[gpu]`.

122-153: ⚠️ Potential issue | 🟠 Major

The ROCm installation section still mixes ROCm and CUDA instructions.

This block currently disagrees with the new ROCm packaging/runtime path: the two pip snippets use different ROCm wheel indexes, both installs still use [gpu], the package list still mentions onnxruntime-gpu, and the success signal only names CUDAExecutionProvider even though Separator.configure_rocm() prefers ROCMExecutionProvider. AMD users can end up installing the wrong runtime or think a correct ROCm setup failed.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@README.md` around lines 122 - 153, The README mixes CUDA and ROCm
instructions; update the ROCm section to consistently use the same ROCm wheel
index, stop referencing CUDA artifacts, and reflect what
Separator.configure_rocm() expects: replace both pip snippets to use a single
ROCm wheel index (e.g., https://download.pytorch.org/whl/rocm7.2), change
package install examples to install PyTorch ROCm first and then pip install
"audio-separator[rocm]" (or the correct extra for ROCm), remove mentions of
onnxruntime-gpu from the Required ROCm Packages (keep onnxruntime and
onnxruntime-rocm only), and change the success log example to "ONNXruntime has
ROCMExecutionProvider available, enabling acceleration" to match
Separator.configure_rocm().
🧹 Nitpick comments (1)
debug_minimal.py (1)

26-44: Include tracebacks in the debug report.

Both exception paths only keep str(e), which drops the stack and import/provider context users usually need to diagnose ROCm failures. For a troubleshooting script, capture the traceback too.

🛠️ Suggested change
+import traceback
...
     except Exception as e:
         output.append(f"Error: {e}")
+        output.append(traceback.format_exc())
...
     except Exception as e:
         output.append(f"Error: {e}")
+        output.append(traceback.format_exc())
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@debug_minimal.py` around lines 26 - 44, The except blocks in the Torch and
ONNX Runtime checks currently append only str(e), losing stack traces; update
the exception handlers in the try/except around importing torch and onnxruntime
(the blocks that append "Error: {e}") to also capture and append the full
traceback (e.g., via traceback.format_exc()) so the debug output includes the
stack and import/provider context—import the traceback module at top if not
present and include the formatted traceback alongside the existing error message
for both exception handlers.
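A runnable version of the pattern; the stand-in module name is deliberately fake, standing in for a broken torch/onnxruntime import:

```python
import traceback

def describe_failure(check):
    """Run a check and capture both the error message and the full traceback."""
    output = []
    try:
        check()
    except Exception as e:
        output.append(f"Error: {e}")
        output.append(traceback.format_exc())  # stack + import context
    return output

def failing_import():
    import no_such_module_for_rocm_debugging  # stand-in for a failing import
```

`traceback.format_exc()` must be called inside the except block, where the current exception is still active, for the stack to be captured.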
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@audio_separator/separator/separator.py`:
- Around line 463-475: The fallback path incorrectly calls configure_cuda() when
onnxruntime_rocm_package is present and "+rocm" appears in torch_version but
only CUDAExecutionProvider is available, which prevents self.is_rocm from being
set and causes load_model() to pass is_rocm=False into CommonSeparator; change
the branch so that when onnxruntime_rocm_package is not None and "+rocm" in
torch_version you call configure_rocm(ort_providers) even in the
CUDAExecutionProvider fallback (instead of configure_cuda), ensuring
self.is_rocm is set correctly before load_model() and CommonSeparator.
- Around line 1264-1269: The guard that raises "Initialization failed..." runs
before the multi-file ensemble path, causing multi-model runs to always error;
change the logic in the separation entry (around the check using
self.model_instance and self.model_filename) so the code first checks if
isinstance(self.model_filename, list) and len(self.model_filename) > 1 and
dispatches to _separate_ensemble() (or otherwise handles the ensemble) before
enforcing that self.model_instance is present, or alter the guard to allow the
ensemble case through (e.g., require model_instance only when model_filename is
not a multi-file list); refer to load_model(), model_filename, model_instance,
and _separate_ensemble() to locate and fix the flow.
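The reordered guard can be sketched with a minimal stand-in class; attribute and method names follow the review comment, while everything else here is simplified for illustration:

```python
class SeparatorSketch:
    """Minimal stand-in showing the reordered guard; the real class is far larger."""

    def __init__(self, model_filename=None, model_instance=None):
        self.model_filename = model_filename
        self.model_instance = model_instance

    def _separate_ensemble(self):
        return f"ensemble of {len(self.model_filename)} models"

    def separate(self):
        # Dispatch multi-model ensembles before requiring a loaded single
        # model, so ensemble runs no longer trip the initialization guard.
        if isinstance(self.model_filename, list) and len(self.model_filename) > 1:
            return self._separate_ensemble()
        if self.model_instance is None:
            raise RuntimeError("Initialization failed: call load_model() first")
        return "single-model separation"
```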

---

Outside diff comments:
In `@audio_separator/separator/separator.py`:
- Around line 108-153: The __init__ signature was changed but the constructor
still reads ensemble_algorithm, ensemble_weights, and ensemble_preset, causing a
NameError; add these three parameters back to the Separator.__init__ signature
(with sensible defaults, e.g., None or existing presets) so the body can
reference them. Also avoid using mutable dict literals as default args for
mdx_params, vr_params, demucs_params, and mdxc_params: change their defaults to
None in the __init__ signature and inside the constructor set them to the
original dict literals if the corresponding argument is None (e.g., if
mdx_params is None: mdx_params = {...}), ensuring each instance gets its own
copy. Ensure references to ensemble_algorithm, ensemble_weights, and
ensemble_preset in the constructor remain unchanged.

In `@pyproject.toml`:
- Around line 49-66: The pyproject.toml extras and dependency entries (e.g.,
onnxruntime-rocm, [tool.poetry.extras] gpu/onnxruntime-gpu) were changed but the
poetry.lock was not updated; regenerate the lockfile and commit it so CI
installs with the new extras succeed: run Poetry to refresh the lock (e.g., run
poetry lock or poetry lock --no-update as appropriate), verify a clean install
with the gpu extra (poetry install -E gpu), and commit the updated poetry.lock
alongside the pyproject.toml change.

---

Duplicate comments:
In `@audio_separator/separator/separator.py`:
- Around line 1027-1029: The code at the yaml parsing site uses yaml.load with
yaml.FullLoader when reading model_data_yaml_filepath (used to produce
model_data), which is unsafe for remote configs; change the logic that reads the
file (the block that assigns model_data) to open the file with a context manager
(with open(model_data_yaml_filepath, encoding="utf-8") as f:) and parse via
yaml.safe_load(f) instead of yaml.load(..., Loader=yaml.FullLoader); ensure any
other occurrences in the same module (and in the download_model_files() code
path) are updated to safe_load to avoid insecure parsing.

In `@README.md`:
- Around line 207-219: The recovery commands currently unconditionally reinstall
onnxruntime-gpu which will break ROCm users; update the README section that
follows the `[gpu]` vs `[rocm]` note to either split the example into two
backend-specific blocks or revert this paragraph to be CUDA-only: keep the
generic uninstall/cache steps (`pip uninstall torch onnxruntime` and `pip cache
purge`) but show separate reinstall commands — `pip install --force-reinstall
torch torchvision torchaudio` + `pip install --force-reinstall onnxruntime-gpu`
for CUDA/NVIDIA (`[gpu]`) and the appropriate ROCm runtime install for AMD
(`[rocm]`) — and clearly label each block so readers for `onnxruntime-gpu` are
only those using `[gpu]`.
- Around line 122-153: The README mixes CUDA and ROCm instructions; update the
ROCm section to consistently use the same ROCm wheel index, stop referencing
CUDA artifacts, and reflect what Separator.configure_rocm() expects: replace
both pip snippets to use a single ROCm wheel index (e.g.,
https://download.pytorch.org/whl/rocm7.2), change package install examples to
install PyTorch ROCm first and then pip install "audio-separator[rocm]" (or the
correct extra for ROCm), remove mentions of onnxruntime-gpu from the Required
ROCm Packages (keep onnxruntime and onnxruntime-rocm only), and change the
success log example to "ONNXruntime has ROCMExecutionProvider available,
enabling acceleration" to match Separator.configure_rocm().

---

Nitpick comments:
In `@debug_minimal.py`:
- Around line 26-44: The except blocks in the Torch and ONNX Runtime checks
currently append only str(e), losing stack traces; update the exception handlers
in the try/except around importing torch and onnxruntime (the blocks that append
"Error: {e}") to also capture and append the full traceback (e.g., via
traceback.format_exc()) so the debug output includes the stack and
import/provider context—import the traceback module at top if not present and
include the formatted traceback alongside the existing error message for both
exception handlers.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: 4203a0a2-14b0-4617-ac2e-eb122c22e4ec

📥 Commits

Reviewing files that changed from the base of the PR and between 1dadc0b and 053532e.

📒 Files selected for processing (5)
  • README.md
  • audio_separator/separator/common_separator.py
  • audio_separator/separator/separator.py
  • debug_minimal.py
  • pyproject.toml

@coconutbird

Hi Wyatt, after a comprehensive review of your pull request, I'm going to need you to break your changes up into a series of smaller PRs, because this is too much to review.

